# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Module for generating the weights for topographic zones."""

import warnings
from typing import Dict, List, Optional

import iris
import numpy as np
from cf_units import Unit
from iris.cube import Cube
from iris.exceptions import InvalidCubeError
from numpy import ndarray

from improver import BasePlugin
from improver.generate_ancillaries.generate_ancillary import (
    GenerateOrographyBandAncils,
    _make_mask_cube,
)


class GenerateTopographicZoneWeights(BasePlugin):
    """Generate weights generated by determining where the orography lies
    within the topographic zones."""

    @staticmethod
    def add_weight_to_upper_adjacent_band(
        topographic_zone_weights: ndarray,
        orography_band: ndarray,
        midpoint: float,
        band_number: float,
        max_band_number: float,
    ) -> ndarray:
        """Once we have found the weight for a point in one band,
        we need to add 1-weight to the band above for points that are above
        the midpoint, unless the band being processed is the uppermost band.

        Args:
            topographic_zone_weights:
                Weights that we have already calculated for the points
                within the orography band.
            orography_band:
                All points within the orography band of interest.
            midpoint:
                The midpoint of the band the point is in.
            band_number:
                The index that corresponds to the band that is currently being
                processed.
            max_band_number:
                The highest index for the bands coordinate in the weights.

        Returns:
            Weights that we have already calculated for the points within
            the orography band that has been updated to account for the
            upper adjacent band.
        """
        weights = topographic_zone_weights[band_number]

        # For points above the midpoint.
        with np.errstate(invalid="ignore"):
            mask_y, mask_x = np.where(orography_band > midpoint)
        if band_number == max_band_number:
            adjacent_band_number = band_number
            topographic_zone_weights[adjacent_band_number, mask_y, mask_x] = 1.0
        else:
            adjacent_band_number = band_number + 1
            topographic_zone_weights[adjacent_band_number, mask_y, mask_x] = (
                1 - weights[mask_y, mask_x]
            )
        return topographic_zone_weights

    @staticmethod
    def add_weight_to_lower_adjacent_band(
        topographic_zone_weights: ndarray,
        orography_band: ndarray,
        midpoint: float,
        band_number: float,
    ) -> ndarray:
        """Once we have found the weight for a point in one band,
        we need to add 1-weight to the band below for points that are below
        the midpoint, unless the band being processed is the lowest band.

        Args:
            topographic_zone_weights:
                Weights that we have already calculated for the points
                within the orography band.
            orography_band:
                All points within the orography band of interest.
            midpoint:
                The midpoint of the band the point is in.
            band_number:
                The index that corresponds to the band that is currently being
                processed.

        Returns:
            Topographic zone array containing the weights that we have
            already calculated for the points within the orography band
            that has been updated to account for the lower adjacent band.
        """
        weights = topographic_zone_weights[band_number]

        # For points below the midpoint.
        with np.errstate(invalid="ignore"):
            mask_y, mask_x = np.where(orography_band < midpoint)
        if band_number == 0:
            adjacent_band_number = band_number
            topographic_zone_weights[adjacent_band_number, mask_y, mask_x] = 1.0
        else:
            adjacent_band_number = band_number - 1
            topographic_zone_weights[adjacent_band_number, mask_y, mask_x] = (
                1 - weights[mask_y, mask_x]
            )
        return topographic_zone_weights

    @staticmethod
    def calculate_weights(points: ndarray, band: List[float]) -> ndarray:
        """Calculate weights where the weight at the midpoint of a band is 1.0
        and the weights at the edge of the band is 0.5. The midpoint is
        assumed to be in the middle of the band.

        Args:
            points:
                The points at which to find the weights.
                e.g. np.array([125]) or np.array([125, 140]).
            band:
                The band to be used for determining the weight that the
                selected points should have within the band
                e.g. [100., 200.].

        Returns:
            The weights generated to indicate the contribution of each
            point to a band.
        """
        weights = np.array([0.5, 1.0, 0.5], np.float32)
        midpoint = np.mean(band)
        band_points = np.array([band[0], midpoint, band[1]], np.float32)
        interpolated_weights = np.interp(points, band_points, weights).astype(
            np.float32
        )
        return interpolated_weights

    def process(
        self,
        orography: Cube,
        thresholds_dict: Dict[str, List[float]],
        landmask: Optional[Cube] = None,
    ) -> Cube:
        """Calculate the weights depending upon where the orography point is
        within the topographic zones.

        Args:
            orography:
                Orography on standard grid.
            thresholds_dict:
                Definition of orography bands required.
                The expected format of the dictionary is e.g.
                `{'bounds': [[0, 50], [50, 200]], 'units': 'm'}`
            landmask:
                Land mask on standard grid, with land points set to one and
                sea points set to zero. If provided sea points are masked
                out in the output array.

        Returns:
            Cube containing the weights depending upon where the orography
            point is within the topographic zones.

        Raises:
            InvalidCubeError: If the orography cube is not two-dimensional.

        Warns:
            UserWarning: If the orography maximum exceeds the uppermost band
                bound, or the orography minimum is below the lowest band
                bound.
        """
        # Check that orography is a 2d cube.
        if len(orography.shape) != 2:
            msg = (
                "The input orography cube should be two-dimensional. "
                "The input orography cube has {} dimensions".format(
                    len(orography.shape)
                )
            )
            raise InvalidCubeError(msg)

        # Test the optional cube explicitly against None rather than relying
        # on the truthiness of a cube object.
        has_landmask = landmask is not None
        sea_points_included = not has_landmask

        # Find bands and midpoints from bounds.
        bands = np.array(thresholds_dict["bounds"], dtype=np.float32)
        threshold_units = thresholds_dict["units"]

        # Create topographic_zone_cube first, so that a cube is created for
        # each band. This will allow the data for neighbouring bands to be
        # put into the cube.
        mask_data = np.zeros(orography.shape, dtype=np.float32)
        topographic_zone_cubes = iris.cube.CubeList([])
        for band in bands:
            topographic_zone_cube = _make_mask_cube(
                mask_data,
                orography.coords(),
                band,
                threshold_units,
                sea_points_included=sea_points_included,
            )
            topographic_zone_cubes.append(topographic_zone_cube)
        topographic_zone_weights = topographic_zone_cubes.concatenate_cube()
        topographic_zone_weights.data = topographic_zone_weights.data.astype(np.float32)

        # Ensure topographic_zone coordinate units is equal to orography units.
        topographic_zone_weights.coord("topographic_zone").convert_units(
            orography.units
        )

        # Read bands from cube, now that they can be guaranteed to be in the
        # same units as the orography. The bands are converted to a list, so
        # that they can be iterated through.
        bands = list(topographic_zone_weights.coord("topographic_zone").bounds)
        midpoints = topographic_zone_weights.coord("topographic_zone").points

        # Raise a warning, if orography extremes are outside the extremes of
        # the bands.
        if np.max(orography.data) > np.max(bands):
            msg = (
                "The maximum orography is greater than the uppermost band. "
                "This will potentially cause the topographic zone weights "
                "to not sum to 1 for a given grid point."
            )
            warnings.warn(msg)

        if np.min(orography.data) < np.min(bands):
            msg = (
                "The minimum orography is lower than the lowest band. "
                "This will potentially cause the topographic zone weights "
                "to not sum to 1 for a given grid point."
            )
            warnings.warn(msg)

        # Insert the appropriate weights into the topographic zone cube. This
        # includes the weights from the band that a point is in, as well as
        # the contribution from an adjacent band.
        max_band_number = len(bands) - 1
        for band_number, band in enumerate(bands):
            # Determine the points that are within the specified band. The
            # lower bound is exclusive and the upper bound is inclusive.
            mask_y, mask_x = np.where(
                (orography.data > band[0]) & (orography.data <= band[1])
            )
            # Points outside the band are set to NaN so that the adjacent
            # band helpers ignore them.
            orography_band = np.full(orography.shape, np.nan, dtype=np.float32)
            orography_band[mask_y, mask_x] = orography.data[mask_y, mask_x]

            # Calculate the weights. This involves calculating the
            # weights for all the orography but only inserting weights
            # that are within the band into the topographic_zone_weights cube.
            weights = self.calculate_weights(orography_band, band)
            topographic_zone_weights.data[band_number, mask_y, mask_x] = weights[
                mask_y, mask_x
            ]

            # Calculate the contribution to the weights from the adjacent
            # lower band.
            topographic_zone_weights.data = self.add_weight_to_lower_adjacent_band(
                topographic_zone_weights.data,
                orography_band,
                midpoints[band_number],
                band_number,
            )

            # Calculate the contribution to the weights from the adjacent
            # upper band.
            topographic_zone_weights.data = self.add_weight_to_upper_adjacent_band(
                topographic_zone_weights.data,
                orography_band,
                midpoints[band_number],
                band_number,
                max_band_number,
            )

        # Metadata updates
        topographic_zone_weights.rename("topographic_zone_weights")
        topographic_zone_weights.units = Unit("1")

        # Mask output weights using a land-sea mask. Each topographic zone
        # slice is processed separately and the slices are then merged back
        # into a single cube.
        topographic_zone_masked_weights = iris.cube.CubeList([])
        for topographic_zone_slice in topographic_zone_weights.slices_over(
            "topographic_zone"
        ):
            if has_landmask:
                topographic_zone_slice.data = GenerateOrographyBandAncils().sea_mask(
                    landmask.data, topographic_zone_slice.data
                )
            topographic_zone_masked_weights.append(topographic_zone_slice)
        topographic_zone_weights = topographic_zone_masked_weights.merge_cube()
        return topographic_zone_weights
